Skip to content

Add FunctionWrappers extension for differentiating through FunctionWrapper#2980

Open
ChrisRackauckas-Claude wants to merge 1 commit intoEnzymeAD:mainfrom
ChrisRackauckas-Claude:functionwrappers-ext
Open

Add FunctionWrappers extension for differentiating through FunctionWrapper#2980
ChrisRackauckas-Claude wants to merge 1 commit intoEnzymeAD:mainfrom
ChrisRackauckas-Claude:functionwrappers-ext

Conversation

@ChrisRackauckas-Claude
Copy link
Contributor

Summary

  • Adds EnzymeFunctionWrappersExt extension that defines forward and reverse mode EnzymeRules for FunctionWrapper{Ret,Args}
  • Extracts the original wrapped function via fw.obj[] and delegates to autodiff_deferred, bypassing the ccall barrier that Enzyme cannot differentiate through
  • Enables packages like NonlinearSolve.jl to use FunctionWrappers for norecompile infrastructure (AutoSpecializeCallable) without manual unwrapping at every Enzyme call site

Motivation

FunctionWrappers.jl wraps Julia functions behind C function pointers using ccall/llvmcall. Enzyme cannot differentiate through this mechanism, throwing EnzymeMutabilityException. NonlinearSolve.jl (PR #838) uses FunctionWrappers for its norecompile infrastructure and currently works around this by manually unwrapping at every call site — a fragile, whack-a-mole approach across 4+ files.

This extension solves the problem at the source: Enzyme automatically differentiates through the original wrapped function.

Implementation

Forward mode rule — single method handles both IIP (Nothing return) and OOP:

  • IIP (RT <: Const): Always runs autodiff_deferred(Forward, ...) to propagate tangents into argument shadow arrays
  • OOP: Uses Duplicated internally (since autodiff_deferred rejects DuplicatedNoNeed), with type assertions for return stability

Reverse mode rulesaugmented_primal + 3 reverse methods:

  • augmented_primal: Executes primal via unwrapped function, caches copies of overwritten args
  • IIP reverse: Reconstructs annotations with cached primals, delegates to autodiff_deferred(Reverse, ...)
  • OOP Active reverse: Delegates then scales per-arg gradients by dret.val using type-stable helper
  • OOP Duplicated/Const reverse: Delegates to autodiff_deferred for gradient accumulation

Design decisions:

  • Only handles func::Const{<:FunctionWrapper} (wrapper itself not differentiated) — covers NonlinearSolve and all standard uses
  • Batch mode (Width > 1) works automatically via BatchDuplicated annotations passed through to autodiff_deferred

Files changed

  • Project.toml — FunctionWrappers weakdep, extension, compat (1.1+), extras
  • ext/EnzymeFunctionWrappersExt.jl — New extension module (209 lines)
  • test/ext/functionwrappers.jl — 8 tests: IIP/OOP × Forward/Reverse, all verified against raw function
  • test/Project.toml — FunctionWrappers test dependency

Test plan

  • IIP forward mode: tangent propagation through FunctionWrapper matches raw function
  • IIP reverse mode: gradient accumulation through FunctionWrapper matches raw function
  • OOP forward mode: scalar derivative through FunctionWrapper matches raw function
  • OOP reverse mode: Active return gradient through FunctionWrapper matches raw function
  • CI passes

🤖 Generated with Claude Code

…apper

FunctionWrappers.jl wraps Julia functions behind C function pointers via
ccall, which Enzyme cannot differentiate through. This adds an
EnzymeFunctionWrappersExt extension that defines EnzymeRules for
FunctionWrapper, extracting the original wrapped function and delegating
to autodiff_deferred.

This enables packages like NonlinearSolve.jl to use FunctionWrappers for
their norecompile infrastructure without needing manual unwrapping at
every call site that might use Enzyme.

Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@github-actions
Copy link
Contributor

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic main) to apply these changes.

Click here to view the suggested changes.
diff --git a/ext/EnzymeFunctionWrappersExt.jl b/ext/EnzymeFunctionWrappersExt.jl
index 04abc32e..e6f71fa7 100644
--- a/ext/EnzymeFunctionWrappersExt.jl
+++ b/ext/EnzymeFunctionWrappersExt.jl
@@ -11,10 +11,10 @@ using Enzyme
 # Helper to reconstruct an annotation with a cached primal value
 @inline _reconstruct_arg(arg::Const, cached, overwritten::Bool) = arg
 @inline function _reconstruct_arg(arg::Duplicated, cached, overwritten::Bool)
-    overwritten && cached !== nothing ? Duplicated(cached, arg.dval) : arg
+    return overwritten && cached !== nothing ? Duplicated(cached, arg.dval) : arg
 end
 @inline function _reconstruct_arg(arg::BatchDuplicated, cached, overwritten::Bool)
-    overwritten && cached !== nothing ? BatchDuplicated(cached, arg.dval) : arg
+    return overwritten && cached !== nothing ? BatchDuplicated(cached, arg.dval) : arg
 end
 @inline _reconstruct_arg(arg::Active, cached, overwritten::Bool) = arg
 
@@ -30,11 +30,11 @@ end
 # Single rule for both IIP (Nothing return) and OOP FunctionWrappers.
 # Extracts the wrapped function and delegates to autodiff_deferred.
 function EnzymeRules.forward(
-    config::EnzymeRules.FwdConfig,
-    func::Const{<:FunctionWrapper},
-    RT::Type{<:Annotation},
-    args::Annotation...,
-)
+        config::EnzymeRules.FwdConfig,
+        func::Const{<:FunctionWrapper},
+        RT::Type{<:Annotation},
+        args::Annotation...,
+    )
     raw_f = unwrap_fw(func.val)
 
     # For IIP functions (Const{Nothing} return), needs_shadow is false but we
@@ -52,13 +52,13 @@ function EnzymeRules.forward(
     # OOP: shadow is needed. Always use Duplicated for autodiff_deferred
     # (it rejects DuplicatedNoNeed).
     RealRt = eltype(RT)
-    if EnzymeRules.needs_primal(config)
+    return if EnzymeRules.needs_primal(config)
         res = Enzyme.autodiff_deferred(ForwardWithPrimal, Const(raw_f), Duplicated, args...)
         # autodiff ForwardWithPrimal returns (derivs, primal)
         if EnzymeRules.width(config) == 1
             return Duplicated(res[2]::RealRt, res[1]::RealRt)
         else
-            return BatchDuplicated(res[2]::RealRt, res[1]::NTuple{EnzymeRules.width(config),RealRt})
+            return BatchDuplicated(res[2]::RealRt, res[1]::NTuple{EnzymeRules.width(config), RealRt})
         end
     else
         res = Enzyme.autodiff_deferred(Forward, Const(raw_f), Duplicated, args...)
@@ -66,7 +66,7 @@ function EnzymeRules.forward(
         if EnzymeRules.width(config) == 1
             return res[1]::RealRt
         else
-            return res[1]::NTuple{EnzymeRules.width(config),RealRt}
+            return res[1]::NTuple{EnzymeRules.width(config), RealRt}
         end
     end
 end
@@ -77,11 +77,11 @@ end
 
 # augmented_primal: execute the forward pass, cache data for reverse
 function EnzymeRules.augmented_primal(
-    config::EnzymeRules.RevConfig,
-    func::Const{<:FunctionWrapper{Ret}},
-    RT::Type{<:Annotation},
-    args::Annotation...,
-) where {Ret}
+        config::EnzymeRules.RevConfig,
+        func::Const{<:FunctionWrapper{Ret}},
+        RT::Type{<:Annotation},
+        args::Annotation...,
+    ) where {Ret}
     raw_f = unwrap_fw(func.val)
     ow = EnzymeRules.overwritten(config)
     nargs = length(args)
@@ -129,12 +129,12 @@ end
 
 # reverse for IIP (Nothing return): accumulate gradients into dval arrays
 function EnzymeRules.reverse(
-    config::EnzymeRules.RevConfig,
-    func::Const{<:FunctionWrapper{Nothing}},
-    ::Type{<:Const{Nothing}},
-    tape,
-    args::Annotation...,
-)
+        config::EnzymeRules.RevConfig,
+        func::Const{<:FunctionWrapper{Nothing}},
+        ::Type{<:Const{Nothing}},
+        tape,
+        args::Annotation...,
+    )
     raw_f, cached_args = tape
     ow = EnzymeRules.overwritten(config)
     nargs = length(args)
@@ -154,12 +154,12 @@ end
 
 # reverse for OOP with Active return: return scaled per-arg gradients
 function EnzymeRules.reverse(
-    config::EnzymeRules.RevConfig,
-    func::Const{<:FunctionWrapper{Ret}},
-    dret::Active,
-    tape,
-    args::Annotation...,
-) where {Ret}
+        config::EnzymeRules.RevConfig,
+        func::Const{<:FunctionWrapper{Ret}},
+        dret::Active,
+        tape,
+        args::Annotation...,
+    ) where {Ret}
     raw_f, cached_args = tape
     ow = EnzymeRules.overwritten(config)
     nargs = length(args)
@@ -181,12 +181,12 @@ end
 
 # reverse for OOP with Duplicated/Const return type (non-Active)
 function EnzymeRules.reverse(
-    config::EnzymeRules.RevConfig,
-    func::Const{<:FunctionWrapper{Ret}},
-    dret::Type{<:Annotation},
-    tape,
-    args::Annotation...,
-) where {Ret}
+        config::EnzymeRules.RevConfig,
+        func::Const{<:FunctionWrapper{Ret}},
+        dret::Type{<:Annotation},
+        tape,
+        args::Annotation...,
+    ) where {Ret}
     if !(dret <: Const)
         raw_f, cached_args = tape
         ow = EnzymeRules.overwritten(config)
diff --git a/test/ext/functionwrappers.jl b/test/ext/functionwrappers.jl
index 94b4e9c8..79af62bf 100644
--- a/test/ext/functionwrappers.jl
+++ b/test/ext/functionwrappers.jl
@@ -10,19 +10,23 @@ using FunctionWrappers: FunctionWrapper
     f_oop(x, p) = p[1] * x^2
 
     @testset "IIP Forward Mode" begin
-        fw = FunctionWrapper{Nothing,Tuple{Vector{Float64},Vector{Float64},Vector{Float64}}}(f!)
+        fw = FunctionWrapper{Nothing, Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}}}(f!)
 
         u = [2.0]; du = zeros(1); p = [3.0]
         ddu = zeros(1); du_u = [1.0]
 
         # Differentiate through FunctionWrapper
-        Enzyme.autodiff(Forward, fw, Const{Nothing},
-            Duplicated(du, ddu), Duplicated(u, du_u), Const(p))
+        Enzyme.autodiff(
+            Forward, fw, Const{Nothing},
+            Duplicated(du, ddu), Duplicated(u, du_u), Const(p)
+        )
 
         # Compare with raw function
         u2 = [2.0]; du2 = zeros(1); ddu2 = zeros(1); du_u2 = [1.0]
-        Enzyme.autodiff(Forward, f!, Const{Nothing},
-            Duplicated(du2, ddu2), Duplicated(u2, du_u2), Const(p))
+        Enzyme.autodiff(
+            Forward, f!, Const{Nothing},
+            Duplicated(du2, ddu2), Duplicated(u2, du_u2), Const(p)
+        )
 
         @test ddu ≈ ddu2
         # ddu[1] should be d/du(p*u^2) * du_u = 3.0 * 2 * 2.0 * 1.0 = 12.0
@@ -30,18 +34,22 @@ using FunctionWrappers: FunctionWrapper
     end
 
     @testset "IIP Reverse Mode" begin
-        fw = FunctionWrapper{Nothing,Tuple{Vector{Float64},Vector{Float64},Vector{Float64}}}(f!)
+        fw = FunctionWrapper{Nothing, Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}}}(f!)
 
         u = [2.0]; du = zeros(1); p = [3.0]
         ddu = [1.0]; du_u = zeros(1)
 
-        Enzyme.autodiff(Reverse, fw, Const{Nothing},
-            Duplicated(du, ddu), Duplicated(u, du_u), Const(p))
+        Enzyme.autodiff(
+            Reverse, fw, Const{Nothing},
+            Duplicated(du, ddu), Duplicated(u, du_u), Const(p)
+        )
 
         # Compare with raw function
         u2 = [2.0]; du2 = zeros(1); ddu2 = [1.0]; du_u2 = zeros(1)
-        Enzyme.autodiff(Reverse, f!, Const{Nothing},
-            Duplicated(du2, ddu2), Duplicated(u2, du_u2), Const(p))
+        Enzyme.autodiff(
+            Reverse, f!, Const{Nothing},
+            Duplicated(du2, ddu2), Duplicated(u2, du_u2), Const(p)
+        )
 
         @test du_u ≈ du_u2
         # du/du[1] of (du[1] = p[1]*u[1]^2) with seed ddu[1]=1.0:
@@ -50,17 +58,21 @@ using FunctionWrappers: FunctionWrapper
     end
 
     @testset "OOP Forward Mode" begin
-        fw_oop = FunctionWrapper{Float64,Tuple{Float64,Vector{Float64}}}(f_oop)
+        fw_oop = FunctionWrapper{Float64, Tuple{Float64, Vector{Float64}}}(f_oop)
 
         x = 3.0; p = [2.0]
         dx = 1.0
 
-        res = Enzyme.autodiff(Forward, fw_oop, Duplicated,
-            Duplicated(x, dx), Const(p))
+        res = Enzyme.autodiff(
+            Forward, fw_oop, Duplicated,
+            Duplicated(x, dx), Const(p)
+        )
 
         # Compare with raw function
-        res2 = Enzyme.autodiff(Forward, f_oop, Duplicated,
-            Duplicated(x, dx), Const(p))
+        res2 = Enzyme.autodiff(
+            Forward, f_oop, Duplicated,
+            Duplicated(x, dx), Const(p)
+        )
 
         @test res[1] ≈ res2[1]
         # d/dx(p*x^2) = 2*p*x = 2*2.0*3.0 = 12.0
@@ -68,16 +80,20 @@ using FunctionWrappers: FunctionWrapper
     end
 
     @testset "OOP Reverse Mode" begin
-        fw_oop = FunctionWrapper{Float64,Tuple{Float64,Vector{Float64}}}(f_oop)
+        fw_oop = FunctionWrapper{Float64, Tuple{Float64, Vector{Float64}}}(f_oop)
 
         x = 3.0; p = [2.0]
 
-        res = Enzyme.autodiff(Reverse, fw_oop, Active,
-            Active(x), Const(p))
+        res = Enzyme.autodiff(
+            Reverse, fw_oop, Active,
+            Active(x), Const(p)
+        )
 
         # Compare with raw function
-        res2 = Enzyme.autodiff(Reverse, f_oop, Active,
-            Active(x), Const(p))
+        res2 = Enzyme.autodiff(
+            Reverse, f_oop, Active,
+            Active(x), Const(p)
+        )
 
         @test res[1][1] ≈ res2[1][1]
         # d/dx(p*x^2) = 2*p*x = 2*2.0*3.0 = 12.0

@ChrisRackauckas-Claude
Copy link
Contributor Author

Session State Summary (for continuation)

What this PR does

Adds EnzymeFunctionWrappersExt — a package extension that defines EnzymeRules for FunctionWrapper from FunctionWrappers.jl, enabling Enzyme to automatically differentiate through wrapped functions.

Problem: FunctionWrappers.jl wraps Julia functions behind C function pointers via ccall/llvmcall. Enzyme cannot differentiate through this mechanism, throwing EnzymeMutabilityException.

Solution: The extension extracts the original function via fw.obj[] and delegates to Enzyme.autodiff_deferred, which uses deferred codegen designed for nested AD.

Files created/modified

  • ext/EnzymeFunctionWrappersExt.jl — Extension module with forward, augmented_primal, and reverse rules
  • test/ext/functionwrappers.jl — 8 test cases (IIP/OOP × Forward/Reverse)
  • Project.toml — Added FunctionWrappers as weakdep, extension, compat entry

Test results (local, Julia 1.10)

All 8 tests pass:

  • IIP forward mode: tangents match raw function
  • IIP reverse mode: gradients match raw function
  • OOP forward mode: scalar return with Duplicated input
  • OOP reverse mode: Active return

Design decisions

  • Only handles func::Const{<:FunctionWrapper} (wrapper itself not differentiated) — covers all standard use cases
  • IIP (FunctionWrapper{Nothing,...}) and OOP (FunctionWrapper{Ret,...}) handled by separate method dispatches
  • Batch mode (Width > 1) works automatically via BatchDuplicated annotations passed through to autodiff_deferred

Downstream dependency

SciML/NonlinearSolve.jl PR #838 depends on this PR. NonlinearSolve currently has ~73 lines of manual Enzyme workaround code (_uses_enzyme_ad(), maybe_unwrap_prob_for_enzyme()) that can be removed once this extension is available. See: SciML/NonlinearSolve.jl#838 (comment)

CI status

CI workflows show action_required — needs a maintainer to approve the workflow run for this branch.

Local repo

  • Path: /home/crackauc/sandbox/tmp_20260218_173108_96644/Enzyme.jl
  • Branch: functionwrappers-ext
  • Remote origin: ChrisRackauckas-Claude/Enzyme.jl
  • Remote upstream: EnzymeAD/Enzyme.jl

Co-Authored-By: Chris Rackauckas accounts@chrisrackauckas.com

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants